import torch, numpy
from matplotlib import pyplot as plt

def accuracy(net, testloader):
    """Compute the top-1 classification accuracy of ``net`` over ``testloader``.

    Each item yielded by ``testloader`` is unpacked as ``(images, labels)``
    and moved to the device of the network's first parameter. ``net`` must
    return a 2-tuple; the first element holds the per-class scores and the
    second element is ignored.

    Prints the floored integer percentage and returns the exact percentage
    as a float. Returns ``-1.0`` when the loader yields no samples.
    """
    device = next(net.parameters()).device
    n_hit = 0
    n_seen = 0

    # Evaluation only: no gradients are needed for the outputs.
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            scores, _ = net(images)
            # The predicted class is the index of the largest score along dim 1.
            predicted = torch.max(scores.data, 1)[1]
            n_seen += labels.size(0)
            n_hit += (predicted == labels).sum().item()

    if n_seen == 0:
        return -1.0

    print(f'Accuracy of the network: {100 * n_hit // n_seen} %')
    return 100 * n_hit / n_seen

'''
    Generation evaluation
'''

import torch
import itertools
import torchvision

def flatten(x):
    """Concatenate the items of every iterable in ``x`` into one flat list.

    Only one level of nesting is removed.
    """
    return list(itertools.chain.from_iterable(x))

import torch
import kornia.metrics as metrics
from tqdm.auto import tqdm

def get_ssim_pairs_kornia(x, y):
    """Return one SSIM score per matching batch entry of ``x`` and ``y``.

    The map returned by ``kornia.metrics.ssim`` (``window_size=3``) is
    flattened per batch element and averaged into a single value.
    """
    batch_size = x.shape[0]
    ssim_map = metrics.ssim(x, y, window_size=3)
    per_pair = ssim_map.reshape(batch_size, -1)
    return per_pair.mean(dim=1)


def get_ssim_all(x, y):
    """Compute SSIM between every image in ``x`` and every image in ``y``.

    Each image of ``y`` is broadcast across the batch of ``x`` and scored
    with ``get_ssim_pairs_kornia``; progress is reported through ``tqdm``.
    The per-``y`` score vectors are stacked and transposed, so the result
    has one row per ``x`` image and one column per ``y`` image.
    """
    n_x = x.shape[0]
    per_target = [
        get_ssim_pairs_kornia(x, y[j:j + 1].expand(n_x, -1, -1, -1))
        for j in tqdm(range(y.shape[0]))
    ]
    return torch.stack(per_target).t()


def normalize_batch(x, ret_all=False):
    """Standardize every (sample, channel) plane of ``x``: (x - mean) / std.

    ``x`` is unpacked as ``(n, c, h, w)``. Mean and standard deviation are
    taken over the ``h * w`` values of each plane (``Tensor.std`` defaults).
    No epsilon is applied, so a plane with zero std is divided by zero.

    Returns the normalized tensor, or ``(normalized, means, stds)`` when
    ``ret_all`` is true; the statistics are shaped ``(n, c, 1, 1)``.
    """
    n, c, h, w = x.shape
    planes = x.reshape(n * c, h * w)
    # Mean of each sample, per channel.
    means = planes.mean(dim=1).reshape(n, c, 1, 1)
    # Standard deviation of each sample, per channel.
    stds = planes.std(dim=1).reshape(n, c, 1, 1)
    normalized = (x - means) / stds
    return (normalized, means, stds) if ret_all else normalized


def l2_dist(x, y, div_dim=False):
    """Pairwise *squared* L2 distance between the samples of ``x`` and ``y``.

    Both inputs are flattened to ``(batch, features)``. The result has
    shape ``(x.shape[0], y.shape[0])``; entry ``[i, j]`` is
    ``||x_i - y_j||^2`` (no square root is taken).

    Args:
        x: tensor whose first dimension indexes samples.
        y: tensor whose first dimension indexes samples; must flatten to
            the same feature count as ``x``.
        div_dim: when true, divide every distance by the feature count.

    Returns:
        Tensor of non-negative squared distances.
    """
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)

    x_sq = x.pow(2).sum(1).view(-1, 1)  # column vector
    y_sq = y.pow(2).sum(1).view(1, -1)  # row vector
    cross = torch.einsum('id,jd->ij', x, y)

    # ||x||^2 + ||y||^2 - 2<x, y> can dip slightly below zero through
    # floating-point cancellation; a squared distance is never negative, and
    # negative values would corrupt ratio comparisons made on the result.
    dists = torch.clamp(x_sq + y_sq - 2 * cross, min=0)

    if div_dim:
        dists = dists / x.shape[1]

    return dists


def ncc_dist(x, y, div_dim=False):
    """Normalized cross-correlation distance between ``x`` and ``y``.

    Both batches are standardized with ``normalize_batch`` and the results
    are compared with ``l2_dist``; ``div_dim`` is forwarded to ``l2_dist``.
    """
    x_norm = normalize_batch(x)
    y_norm = normalize_batch(y)
    return l2_dist(x_norm, y_norm, div_dim)


def transform_vmin_vmax_batch(x, min_max=None):
    """Linearly rescale each image in ``x`` so that [min, max] maps to [0, 1].

    With ``min_max=None`` the minimum and maximum are taken per batch
    element over all of its remaining values (all channels together) and
    broadcast back through three trailing singleton dimensions. Otherwise
    ``min_max`` is unpacked as ``(vmin, vmax)`` and applied to the whole
    batch. There is no guard for ``vmax == vmin``.
    """
    if min_max is not None:
        vmin, vmax = min_max
    else:
        # Per-image extrema across every channel and pixel.
        per_image = x.data.reshape(x.shape[0], -1)
        vmin = per_image.min(dim=1)[0][:, None, None, None]
        vmax = per_image.max(dim=1)[0][:, None, None, None]
    value_range = vmax - vmin
    return (x - vmin).div(value_range)


def viz_nns(x, y, max_per_nn=None, metric='ncc', ret_all=False):
    """
    Pair every image in x with its nearest neighbour in y.

    The pairs are ordered from the smallest distance to the largest and
    returned interleaved (x0, y0, x1, y1, ...) after per-image min-max
    rescaling.

    metric: 'ncc' (ncc_dist) or 'l2' (l2_dist); anything else raises
        ValueError.
    max_per_nn: filter duplicates -- keep at most max_per_nn x-images per
        y-image, preferring the closest ones. None keeps every pair.
    ret_all: when true return (qq, v, xx, yy, keep) instead of (qq, v),
        where keep is the list of (nn idx in y, value, idx of x) tuples
        (None when max_per_nn is None).
    """

    if metric == 'ncc':
        dists = ncc_dist(x, y)
    elif metric == 'l2':
        dists = l2_dist(x, y)
    else:
        raise ValueError(f'Unknown metric={metric}')

    # For each x: distance to, and index of, its closest y.
    v, nn_idx = dists.min(dim=1)

    keep = None
    if max_per_nn is not None:
        # (nn idx in y, value, idx of x), built from Python scalars so the
        # indices stay exact integers instead of passing through a float
        # tensor. Sorting groups the entries by y index, closest x first.
        triples = sorted(zip(nn_idx.tolist(), v.tolist(), range(v.shape[0])))

        # filter duplicates (leave only max_per_nn from each image from y)
        keep = []
        for _, same_nn in itertools.groupby(triples, key=lambda t: t[0]):
            keep.extend(itertools.islice(same_nn, max_per_nn))

        # sort by best value first
        keep.sort(key=lambda t: t[1])

        x_sel = torch.tensor([t[2] for t in keep], dtype=torch.long, device=x.device)
        y_sel = torch.tensor([t[0] for t in keep], dtype=torch.long, device=y.device)
        xx = x[x_sel]
        yy = y[y_sel]
        # Keep the distances on the dtype/device they were computed on, as
        # in the unfiltered branch below.
        v = torch.tensor([t[1] for t in keep], dtype=v.dtype, device=v.device)
    else:
        _, sidxs = v.sort()
        xx = x[sidxs]
        yy = y[nn_idx[sidxs]]

    qq = torch.stack(flatten(list(zip(xx, yy))))
    qq = transform_vmin_vmax_batch(qq)

    if ret_all:
        return qq, v, xx, yy, keep

    # Interleaved (x, nearest y) images and the matching distances.
    return qq, v


def get_evaluation_score_dssim(xxx, yyy, ds_mean, vote=None, show=False):
    """Score candidates ``xxx`` against targets ``yyy`` with DSSIM.

    For every image in ``yyy`` the closest image(s) in ``xxx`` are found by
    NCC distance on half-resolution bicubic copies. ``ds_mean`` is added to
    both sides, the matched candidates are min-max rescaled per image, and
    each pair is scored with ``dssim = (1 - ssim) / 2``.

    Args:
        xxx: candidate images (4-D batch; passed to ``interpolate``).
        yyy: target images (4-D batch).
        ds_mean: value added to both batches before scoring.
        vote: ``None`` to use the single nearest candidate, or one of
            ``'min'``, ``'mean'``, ``'median'``, ``'mode'`` to combine all
            candidates whose distance is within a factor 1.1 of the
            nearest one.
        show: when true, draw the grid with matplotlib.

    Returns:
        ``(ev_score, grid)``: the mean of the 10 smallest DSSIM values as a
        float, and a ``torchvision`` grid of the first 100 interleaved
        (candidate, target) images ordered by increasing DSSIM.

    Raises:
        ValueError: if ``vote`` is not one of the supported values.
    """
    # Fail fast on a bad option, before any expensive work is done.
    if vote not in (None, 'min', 'mean', 'median', 'mode'):
        raise ValueError(f'Unknown vote={vote}')

    xxx = xxx.clone()
    yyy = yyy.clone()  # cloned because ds_mean is added in place below

    # The nearest-neighbour search runs on half-resolution images.
    x2search = torch.nn.functional.interpolate(xxx, scale_factor=1 / 2, mode='bicubic')
    y2search = torch.nn.functional.interpolate(yyy, scale_factor=1 / 2, mode='bicubic')
    D = ncc_dist(y2search, x2search, div_dim=True)

    # Row i lists the candidates from closest to farthest for target i.
    dists, idxs = D.sort(dim=1, descending=False)

    if vote is not None:
        # Ignore distant nearest-neighbours: keep the best match plus the
        # following ones while their distance stays within 1.1x of the best.
        xs_idxs = []
        for i in range(dists.shape[0]):
            x_idxs = [idxs[i, 0].item()]
            for j in range(1, dists.shape[1]):
                if (dists[i, j] / dists[i, 0]) < 1.1:
                    x_idxs.append(idxs[i, j].item())
                else:
                    break
            xs_idxs.append(x_idxs)

        # Voting
        xs = []
        for x_idxs in xs_idxs:
            if vote == 'min':
                x_voted = xxx[x_idxs[0]].unsqueeze(0)
            elif vote == 'mean':
                x_voted = xxx[x_idxs].mean(dim=0, keepdim=True)
            elif vote == 'median':
                x_voted = xxx[x_idxs].median(dim=0, keepdim=True).values
            else:  # 'mode' (vote was validated above)
                x_voted = xxx[x_idxs].mode(dim=0, keepdim=True).values
            xs.append(x_voted)
        xx = torch.cat(xs, dim=0).clone()
        yy = yyy
    else:
        xx = xxx[idxs[:, 0]]
        yy = yyy

    # Scale to images
    yy += ds_mean
    xx = transform_vmin_vmax_batch(xx + ds_mean)

    # Score
    ssims = get_ssim_pairs_kornia(xx, yy)
    dssim = (1 - ssims) / 2
    dssims, sort_idxs = dssim.sort(descending=False)

    # Sort & Show
    xx = xx[sort_idxs]
    yy = yy[sort_idxs]

    qq = torch.stack(flatten(list(zip(xx, yy))))
    grid = torchvision.utils.make_grid(qq[:100], normalize=False, nrow=20)

    if show:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(80 * 2, 10 * 2))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())

    # Mean DSSIM of the 10 best pairs, and the grid of matched pairs.
    ev_score = dssims[:10].mean()
    return ev_score.item(), grid


def similar_visualization(x, y, image_path):
    """Match each image in ``y`` with its nearest image in ``x`` and save the pairs.

    Both batches are detached, moved to the CPU and cloned. The nearest
    neighbour of every ``y`` image is found by NCC distance on
    half-resolution bicubic copies. Matched images and targets are min-max
    rescaled per image, scored with ``dssim = (1 - ssim) / 2`` and ordered
    by increasing DSSIM. The first 64 interleaved images are written to
    ``image_path`` (8 per row).

    Prints and returns the mean of the 10 smallest DSSIM values, together
    with the grid tensor built from the same 64 images.
    """
    x = x.detach().cpu().clone()
    y = y.detach().cpu().clone()

    # The nearest-neighbour search runs on half-resolution images.
    x_small = torch.nn.functional.interpolate(x, scale_factor=1 / 2, mode='bicubic')
    y_small = torch.nn.functional.interpolate(y, scale_factor=1 / 2, mode='bicubic')
    _, order = ncc_dist(y_small, x_small, div_dim=True).sort(dim=1, descending=False)

    # Scale to images
    targets = transform_vmin_vmax_batch(y)
    matches = transform_vmin_vmax_batch(x[order[:, 0]])

    # Score
    dssim = (1 - get_ssim_pairs_kornia(matches, targets)) / 2
    dssims, ranking = dssim.sort(descending=False)

    # Best pairs first, interleaved as (match, target, match, target, ...).
    matches = matches[ranking]
    targets = targets[ranking]
    interleaved = torch.stack([img for pair in zip(matches, targets) for img in pair])

    shown = interleaved[:64]
    torchvision.utils.save_image(shown, image_path, nrow=8)
    grid = torchvision.utils.make_grid(shown, nrow=8)

    score = dssims[:10].mean().item()
    print("ev_score", score)
    return score, grid

def visualization(x, image_path):
    """Save the first 64 images of ``x`` to ``image_path`` and return them as a grid.

    The batch is min-max rescaled per image first; both the saved file and
    the returned grid use 8 images per row.
    """
    preview = transform_vmin_vmax_batch(x)[:64]
    torchvision.utils.save_image(preview, image_path, nrow=8)
    return torchvision.utils.make_grid(preview, nrow=8)

def scatter(f, a, x, y, colors, markers):
    """Draw the points of classes 0, 1 and 2 on axes ``a`` and strip its decorations.

    Rows of ``x`` whose label in ``y`` equals the class index are plotted
    using columns 0 and 1 as coordinates, with size 200 and the marker and
    color looked up by class index. The frame and both tick sets are then
    removed. Returns ``(f, a)`` unchanged.
    """
    for label in (0, 1, 2):
        members = y == label
        a.scatter(
            x[members, 0],
            x[members, 1],
            s=200,
            marker=markers[label],
            color=colors[label],
        )

    a.set_frame_on(False)
    a.set_xticks([])
    a.set_yticks([])
    return f, a

def visualization_2D(x, y, training_input, training_label, visualization_path):
    """Plot training and generated 2-D points by class and save the figure.

    Training points are drawn first with ``"x"`` markers, then the
    generated points ``x`` (labels ``y``) with ``"."`` markers; classes 0,
    1 and 2 share one color each across both sets. All tensors are detached
    and moved to the CPU before plotting. The 10x10 figure is saved to
    ``visualization_path`` with a tight bounding box.

    Returns the ``(figure, axes)`` pair.
    """
    palette = {0: "#de4649", 1: "#6aa35c", 2: "#3c73a8"}
    training_markers = dict.fromkeys(range(3), "x")
    generated_markers = dict.fromkeys(range(3), ".")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    layers = (
        (training_input, training_label, training_markers),
        (x, y, generated_markers),
    )
    for points, labels, marker_map in layers:
        fig, ax = scatter(fig, ax, points.detach().cpu(), labels.detach().cpu(), palette, marker_map)

    fig.savefig(visualization_path, bbox_inches='tight')
    return fig, ax

def image_from_torch_to_numpy(img_tensor):
    """Convert a 3-D image tensor into a ``uint8`` NumPy array with axes reordered.

    NOTE(review): no ``return`` statement is visible in the lines reviewed
    here; either the function continues beyond this view or it returns
    ``None`` -- confirm against the full file.
    """
    # Move to host memory and reorder axes (0, 1, 2) -> (1, 2, 0);
    # presumably (C, H, W) -> (H, W, C) -- TODO confirm.
    img_np = img_tensor.cpu().numpy().transpose(1, 2, 0)
    # Scale by 255 and cast to uint8 (no clipping is applied); presumably
    # the input values lie in [0, 1] -- TODO confirm.
    img_np = numpy.uint8(img_np * 255)